home *** CD-ROM | disk | FTP | other *** search
/ HPAVC / HPAVC CD-ROM.iso / BPNN133U.ZIP / BPTRAIN.C < prev    next >
C/C++ Source or Header  |  1992-11-19  |  6KB  |  274 lines

  1. /*
  2. *-----------------------------------------------------------------------------
  3. *    file:    bptrain.c
  4. *    desc:    back propagation Multi Layer Perceptron (MLP) training
  5. *    by:    patrick ko
  6. *    date:    02 aug 1991
  7. *    revi:    v1.32u 26 apr 1992
  8. *    revi:    v1.33u 19 nov 1992 - cparser.c bug fixed
  9. *-----------------------------------------------------------------------------
  10. */
  11.  
  12. #include <stdio.h>
  13. #include <stdlib.h>
  14. #ifdef __TURBOC__
  15. #include <mem.h>
  16. #include <alloc.h>
  17. #endif
  18.  
  19. #include "nntype.h"
  20. #include "nncreat.h"
  21. #include "nntrain.h"
  22. #include "nnerror.h"
  23. #include "cparser.h"
  24. #include "bptrainv.h"
  25. #include "timer.h"
  26.  
  27. #define MAXHIDDEN    128
  28.  
  29. static INTEGER    hiddencnt = 0;
  30. static INTEGER    hidden[MAXHIDDEN];
  31. static INTEGER    output;
  32. static INTEGER    input;
  33. static INTEGER    totalhidden;
  34. static INTEGER    totalpatt = 0;
  35. static REAL    trainerr = ERROR_DEFAULT;
  36. static INTEGER    report = 0;
  37. static INTEGER    timer = 0;
  38. static long int    tdump = 0;
  39.  
  40. static VECTOR    **inputvect;
  41. static VECTOR    **targtvect;
  42.  
  43. extern REAL    TOLER;
  44.  
  45. static char    tname[128];
  46. /*
  47. *    dump file name with default
  48. */
  49. static char    dname[128] = "bptrain.dmp";
  50. static char    dinname[128] = "";
  51.  
  52. int    usage( )
  53. {
  54.     printf( "%s %s - by %s\n", PROGNAME, VERSION, AUTHOR );
  55.     printf( "Copyright (c) 1992 All Rights Reserved. %s\n\n", DATE );
  56.     printf( "Description: backprop neural net training with adaptive coefficients\n");
  57.     printf( "Usage: %s @file -i=# -o=# -hh=# {-h=#} -samp=# -ftrain=<fn>\n", PROGNAME);
  58.     printf( "[-fdump=<fn>] [-fdumpin=<fn>] -r=# [-t] [-tdump=#] [-w+=# -w-=#]\n" );
  59.     printf( "[-err=] [-torerr=] [// ...]\n");
  60.     printf( "Example: " );
  61.     printf( "create and train a 2x4x3x1 dimension NN with 10 samples\n");
  62.     printf( "%s -i=2 -o=1 -hh=2 -h=4 -h=3 -err=0.01 ", PROGNAME );
  63.     printf( "-ftrain=input.trn -samp=10\n" );
  64.     printf( "Where:\n" );
  65.     printf( "-i=,-o=     dimension of input/output layer\n" );
  66.     printf( "-hh=        number of hidden layers\n" );
  67.     printf( "-h=         each hidden layer dimension (may be multiple)\n" );
  68.     printf( "-ftrain=    name of train file containing inputs and targets\n" );
  69.     printf( "-fdump=     name of output weights dump file\n" );
  70.     printf( "-fdumpin=   name of input weights dump file (if any)\n");
  71.     printf( "-samp=      number of train input patterns in train file\n" );
  72.     printf( "-r=         report training status interval\n" );
  73.     printf( "-t          time the training (good for non-Unix)\n" );
  74.     printf( "-tdump=     time for periodic dump (specify seconds)\n");
  75.     printf( "-w+=        initial random weight upper bound\n" );
  76.     printf( "-w-=        initial random weight lower bound\n" );
  77.     printf( "-err=       mean square per unit train error ");
  78.     printf( "(def=%f)\n", ERROR_DEFAULT );
  79.     printf( "-torerr=    tolerance error (def=%f)\n", TOLER_DEFAULT);
  80.     exit (0);
  81. }
  82.  
  83. int    parse( )
  84. {
  85.     int    cmd;
  86.     char    rest[128];
  87.     int    resti;
  88.     long    restl;
  89.  
  90.     while ((cmd = cmdget( rest ))!= -1)
  91.         {
  92.         resti = atoi(rest);
  93.         restl = atol(rest);
  94.         switch (cmd)
  95.             {
  96.             case CMD_DIMINPUT:
  97.                 input = resti; break;
  98.             case CMD_DIMOUTPUT:
  99.                 output = resti; break;
  100.             case CMD_DIMHIDDENY:
  101.                 if (input <= 0 || output <= 0)
  102.                     {
  103.                     error( NNIOLAYER );
  104.                     }
  105.                 if (resti > MAXHIDDEN)
  106.                     {
  107.                     error( NN2MANYLAYER );
  108.                     }
  109.                 totalhidden = resti; break;
  110.             case CMD_DIMHIDDEN:
  111.                 if (hiddencnt >= totalhidden)
  112.                     {
  113.                     /*
  114.                     * hidden layers more than specified
  115.                     */
  116.                     break;
  117.                     }
  118.                 hidden[hiddencnt++] = resti;
  119.                 break;
  120.             case CMD_TRAINFILE:
  121.                 strcpy( tname, rest );
  122.                 break;
  123.             case CMD_TOTALPATT:
  124.                 totalpatt = resti;
  125.                 break;
  126.             case CMD_DUMPFILE:
  127.                 strcpy( dname, rest );
  128.                 break;
  129.             case CMD_DUMPIN:
  130.                 strcpy( dinname, rest );
  131.                 break;
  132.             case CMD_TRAINERR:
  133.                 trainerr = atof( rest );
  134.                 break;
  135.             case CMD_TOLER:
  136.                 TOLER = atof( rest );
  137.                 break;
  138.             case CMD_REPORT:
  139.                 report = resti;
  140.                 break;
  141.             case CMD_TIMER:
  142.                 timer = 1;
  143.                 break;
  144.             case CMD_TDUMP:
  145.                 tdump = restl;
  146.                 break;
  147.             case CMD_WPOS:
  148.                 UB = atof(rest);
  149.                 break;
  150.             case CMD_WNEG:
  151.                 LB = atof(rest);
  152.                 break;
  153.             case CMD_COMMENT:
  154.                 break;
  155.             case CMD_NULL:
  156.                 printf( "%s: unknown command [%s]\n", PROGNAME, rest );
  157.                 exit (2);
  158.                 break;
  159.             }
  160.         }
  161.         if (hiddencnt < totalhidden)
  162.             {
  163.             error( NN2MANYHIDDEN );
  164.             }
  165. }
  166.  
  167. int    gettrainvect( tname )
  168. char    *tname;
  169. {
  170.     int    i, j, cnt;
  171.     VECTOR    *tmp;
  172.     FILE    *ft;
  173.  
  174.  
  175.     ft = fopen( tname, "r" );
  176.     if (ft == NULL)
  177.         {
  178.         error( NNTFRERR );
  179.         }
  180.  
  181.     inputvect = (VECTOR **) malloc( sizeof(VECTOR *) * totalpatt );
  182.     targtvect = (VECTOR **) malloc( sizeof(VECTOR *) * totalpatt );
  183.  
  184.     if (totalpatt <= 0)
  185.         {
  186.         error( NN2FEWPATT );
  187.         }
  188.     for (i=0; i<totalpatt; i++)
  189.         {
  190.         /*
  191.         *    allocate input patterns
  192.         */
  193.         tmp = v_creat( input );
  194.         for (j=0; j<input; j++)
  195.             {
  196.             cnt = fscanf( ft, "%lf", &tmp->vect[j] );
  197.             if (cnt < 1)
  198.                 {
  199.                 error( NNTFIERR );
  200.                 }
  201.             }
  202.         *(inputvect + i) = tmp;
  203.  
  204.         tmp = v_creat( output );
  205.         for (j=0; j<output; j++)
  206.             {
  207.             cnt = fscanf( ft, "%lf", &tmp->vect[j] );
  208.             if (cnt < 1)
  209.                 {
  210.                 error( NNTFIERR );
  211.                 }
  212.             }
  213.         *(targtvect + i) = tmp;
  214.         }
  215.     fclose( ft );
  216. }
  217.  
  218.  
  219. int    main( argc, argv )
  220. int    argc;
  221. char    ** argv;
  222. {
  223.     NET    *nn;
  224.     FILE    *fdump;
  225.  
  226.     if (argc < 2)
  227.         {
  228.         usage();
  229.         }
  230.     else
  231.         {
  232.         cmdinit( argc, argv );
  233.         parse();
  234.         }
  235.  
  236.     /* create a neural net */
  237.     nn = nn_creat( totalhidden + 1, input, output, hidden );
  238.  
  239.     gettrainvect( tname );
  240.  
  241.     /* read last dump, if any */
  242.     if (*dinname != NULL)
  243.         {
  244.         printf( "%s: opening dump file [%s] ...\n", PROGNAME, dinname);
  245.         if ((fdump = fopen( dinname, "r" )) != NULL)
  246.             {
  247.             nn_load( fdump, nn );
  248.             fclose( fdump );
  249.             }
  250.         }
  251.  
  252.     printf( "%s: start\n", PROGNAME );
  253.  
  254.     if (timer)
  255.         timer_restart();
  256.     /*
  257.     * the default training error, ..., etc can be incorporated into
  258.     * the interface - if you like.
  259.     */
  260.     nnbp_train( nn, inputvect, targtvect, totalpatt,
  261.     trainerr, ETA_DEFAULT, ALPHA_DEFAULT, report, tdump, dname );
  262.  
  263.     if (timer)
  264.         printf("%s: time elapsed = %ld secs\n", PROGNAME, timer_stop());
  265.  
  266.     printf( "%s: dump neural net to [%s]\n", PROGNAME, dname );
  267.  
  268.     fdump = fopen( dname, "w" );
  269.     nn_dump( fdump, nn );
  270.     fclose(fdump);
  271. }
  272.  
  273.  
  274.